# Copyright (c) 2018-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from __future__ import print_function

import os
import argparse
import wandb
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from advertorch.context import ctx_noparamgrad_and_eval
from advertorch.test_utils import LeNet5
from advertorch_examples.utils import get_mnist_train_loader
from advertorch_examples.utils import get_mnist_test_loader
from advertorch_examples.utils import TRAINED_MODEL_PATH

'''
Adversarially train a robust model on MNIST.
Standard (clean) training is also supported; select with --mode.
'''
if __name__ == '__main__':
    # Command-line options for the MNIST training script.
    arg_parser = argparse.ArgumentParser(description='Train MNIST')
    # random seed for reproducibility
    arg_parser.add_argument('--seed', type=int, default=0)
    # training mode: standard ("cln", default) or adversarial ("adv")
    arg_parser.add_argument('--mode', default="cln", help="cln | adv")
    # mini-batch size used for training
    arg_parser.add_argument('--train_batch_size', type=int, default=50)
    # mini-batch size used for evaluation
    arg_parser.add_argument('--test_batch_size', type=int, default=1000)
    # print/log training metrics once every this many batches
    arg_parser.add_argument('--log_interval', type=int, default=200)
    args = arg_parser.parse_args()

    # Start a wandb run, named by the current timestamp.
    print('Start wandb, view at https://wandb.ai/')
    wandb.init(project='tutorial_train_mnist', name=time.strftime('%m%d%H%M%S'))
    train_log = {}  # per-batch training metrics
    val_log = {}    # per-epoch evaluation metrics

    torch.manual_seed(args.seed)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Choose epoch budget and checkpoint filename per training mode.
    if args.mode == "cln":
        flag_advtrain = False
        nb_epoch = 10   # clean training converges quickly on MNIST
        model_filename = "mnist_lenet5_clntrained.pt"
    elif args.mode == "adv":
        flag_advtrain = True
        nb_epoch = 90   # adversarial training needs a longer budget
        model_filename = "mnist_lenet5_advtrained.pt"
    else:
        # Fix: a bare `raise` with no active exception fails with
        # "RuntimeError: No active exception to re-raise"; raise an
        # informative error for an unrecognized --mode instead.
        raise ValueError('unknown mode: {}'.format(args.mode))

    # Build the MNIST train/test DataLoaders.
    train_loader = get_mnist_train_loader(batch_size=args.train_batch_size,
                                          shuffle=True)
    test_loader = get_mnist_test_loader(batch_size=args.test_batch_size,
                                        shuffle=False)

    # LeNet-5 classifier on the chosen device, trained with Adam.
    # (nn.Module.to returns the module itself, so chaining is equivalent
    # to a separate model.to(device) call.)
    model = LeNet5().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Build the attack used to craft adversarial examples (adv mode only):
    # L-infinity-bounded projected gradient descent.
    if flag_advtrain:
        from advertorch.attacks import LinfPGDAttack
        adversary = LinfPGDAttack(
            model,          # model under attack
            loss_fn=nn.CrossEntropyLoss(reduction="sum"),  # attack objective
            eps=0.3,        # maximum L-inf perturbation magnitude
            nb_iter=40,     # number of PGD iterations
            eps_iter=0.01,  # step size per iteration
            rand_init=True, # random start inside the eps-ball
            clip_min=0.0,   # lower clamp for perturbed inputs
            clip_max=1.0,   # upper clamp for perturbed inputs
            targeted=False) # untargeted attack

    # Training loop.
    for epoch in range(nb_epoch):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            # In adversarial mode, replace each clean batch with
            # adversarial examples before the parameter update.
            if flag_advtrain:
                # when performing attack, the model needs to be in eval mode
                # also the parameters should NOT be accumulating gradients
                with ctx_noparamgrad_and_eval(model):
                    data = adversary.perturb(data, target)

            optimizer.zero_grad()
            output = model(data)
            # Fix: reduction='elementwise_mean' was deprecated and removed
            # from PyTorch; 'mean' is the supported equivalent.
            loss = F.cross_entropy(output, target, reduction='mean')
            loss.backward()
            optimizer.step()

            # Log to console and wandb every `log_interval` batches.
            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx *
                    len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))

                train_log['train/epoch'] = epoch
                train_log['train/loss'] = loss.item()
                # percentage of this epoch's batches processed so far
                train_log['epoch/data_rate'] = 100. * batch_idx / len(train_loader)
                wandb.log(train_log)


        # ---- per-epoch evaluation on the test set ----
        model.eval()
        cln_loss_sum = 0    # accumulated clean cross-entropy (sum reduction)
        cln_hits = 0        # correctly classified clean examples

        if flag_advtrain:
            adv_loss_sum = 0    # accumulated adversarial cross-entropy
            adv_hits = 0        # correctly classified adversarial examples

        for batch, labels in test_loader:
            batch, labels = batch.to(device), labels.to(device)

            # Clean loss / accuracy on the untouched test batch.
            with torch.no_grad():
                logits = model(batch)
            cln_loss_sum += F.cross_entropy(
                logits, labels, reduction='sum').item()
            guesses = logits.max(1, keepdim=True)[1]
            cln_hits += guesses.eq(labels.view_as(guesses)).sum().item()

            # Robust loss / accuracy: attack the batch, then re-evaluate.
            if flag_advtrain:
                perturbed = adversary.perturb(batch, labels)
                with torch.no_grad():
                    logits = model(perturbed)
                adv_loss_sum += F.cross_entropy(
                    logits, labels, reduction='sum').item()
                guesses = logits.max(1, keepdim=True)[1]
                adv_hits += guesses.eq(labels.view_as(guesses)).sum().item()

        # Report averaged loss, counts, and accuracy for this epoch.
        n_test = len(test_loader.dataset)
        avg_cln_loss = cln_loss_sum / n_test
        print('\nTest set: avg cln loss: {:.4f},'
              ' cln acc: {}/{} ({:.0f}%)\n'.format(
                  avg_cln_loss, cln_hits, n_test,
                  100. * cln_hits / n_test))
        val_log['test/avg_cln_loss'] = avg_cln_loss
        val_log['test/cln_correct_numbers'] = cln_hits
        val_log['test/dataset_numbers'] = n_test
        val_log['test/cln_acc'] = 100. * cln_hits / n_test

        if flag_advtrain:
            avg_adv_loss = adv_loss_sum / n_test
            print('Test set: avg adv loss: {:.4f},'
                  ' adv acc: {}/{} ({:.0f}%)\n'.format(
                      avg_adv_loss, adv_hits, n_test,
                      100. * adv_hits / n_test))
            val_log['test/avg_adv_loss'] = avg_adv_loss
            val_log['test/adv_correct_numbers'] = adv_hits
            val_log['test/adv_acc'] = 100. * adv_hits / n_test

        # Push this epoch's evaluation metrics to wandb.
        wandb.log(val_log)

    # Persist the trained weights: once to the advertorch examples path,
    # once to a local ./data copy.
    torch.save(
        model.state_dict(),
        os.path.join(TRAINED_MODEL_PATH, model_filename))

    # Fix: torch.save does not create missing directories — ensure the
    # local output directory exists before writing there.
    os.makedirs('./data', exist_ok=True)
    torch.save(
        model.state_dict(),
        os.path.join('./data', model_filename))
